{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CqWhSlfDqmns"
   },
   "source": [
    "<a id='prerequisits'></a>\n",
    "\n",
    "# Prerequisits\n",
    "This section installs `gradslam` (if not already installed), imports the necessary packages for the tutorial, and downloads 'lr kt1' (the first trajectory) of [ICL-NUIM dataset](https://www.doc.ic.ac.uk/~ahanda/VaFRIC/iclnuim.html) and structures it as below:\n",
    "```\n",
    "ICL\n",
    "    living_room_traj1_frei_png\n",
    "        depth/    rgb/    associations.txt    livingRoom1.gt.freiburg    livingRoom1n.gt.sim\n",
    "```\n",
    "\n",
    "\n",
    "We set the ICL path variable: `icl_path='ICL/'`. The ICL data is loaded into the following variables: <br>\n",
    "\n",
    "* `colors`: of shape (batch_size, sequence_length, height, width, 3) <br>\n",
    "* `depths`: of shape (batch_size, sequence_length, height, width, 1) <br>\n",
    "* `intrinsics`: of shape (batch_size, 1, 4, 4) <br>\n",
    "* `poses`: of shape (batch_size, sequence_length, 4, 4) <br>\n",
    "\n",
    "Finally `RGBDImages` is created from the ICL data and visualized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9opFP8sDqmnt"
   },
   "outputs": [],
   "source": [
    "# install gradslam (if not installed)\n",
    "try:\n",
    "    import gradslam as gs\n",
    "except ImportError:\n",
    "    print(\"Installing gradslam...\")\n",
    "    !pip install 'git+https://github.com/gradslam/gradslam.git' -q\n",
    "    print('Installed')\n",
    "\n",
    "# import necessary packages\n",
    "import gradslam as gs\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "from gradslam import Pointclouds, RGBDImages\n",
    "from gradslam.datasets import ICL\n",
    "from gradslam.slam import PointFusion\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# download 'lr kt1' of ICL dataset\n",
    "if not os.path.isdir('ICL'):\n",
    "    os.mkdir('ICL')\n",
    "if not os.path.isdir('ICL/living_room_traj1_frei_png'):\n",
    "    print('Downloading ICL/living_room_traj1_frei_png dataset...')\n",
    "    os.mkdir('ICL/living_room_traj1_frei_png')\n",
    "    !wget http://www.doc.ic.ac.uk/~ahanda/living_room_traj1_frei_png.tar.gz -P ICL/living_room_traj1_frei_png/ -q\n",
    "    !tar -xzf ICL/living_room_traj1_frei_png/living_room_traj1_frei_png.tar.gz -C ICL/living_room_traj1_frei_png/\n",
    "    !rm ICL/living_room_traj1_frei_png/living_room_traj1_frei_png.tar.gz\n",
    "    !wget https://www.doc.ic.ac.uk/~ahanda/VaFRIC/livingRoom1n.gt.sim -P ICL/living_room_traj1_frei_png/ -q\n",
    "    print('Downloaded.')\n",
    "icl_path = 'ICL/'\n",
    "\n",
    "# load dataset\n",
    "dataset = ICL(icl_path, seqlen=8, height=240, width=320)\n",
    "loader = DataLoader(dataset=dataset, batch_size=2)\n",
    "colors, depths, intrinsics, poses, *_ = next(iter(loader))\n",
    "\n",
    "# create rgbdimages object\n",
    "rgbdimages = RGBDImages(colors, depths, intrinsics, poses)\n",
    "rgbdimages.plotly(0).update_layout(autosize=False, height=600, width=400).show()"
   ]
  },
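  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes listed in the prerequisites can be mimicked with dummy tensors. The sketch below is only an illustration: the sizes, focal lengths and principal point are made-up values (not the ICL calibration), and placing `fx`, `fy`, `cx`, `cy` inside a 4x4 identity matrix is an assumption that matches how `K` is indexed in the visualization code of this tutorial. All names carry a `demo_` prefix so that the tutorial's own variables are left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# hypothetical sizes and camera parameters, for illustration only\n",
    "demo_batch, demo_seq, demo_h, demo_w = 2, 8, 240, 320\n",
    "demo_fx, demo_fy, demo_cx, demo_cy = 250.0, 250.0, 160.0, 120.0\n",
    "\n",
    "# dummy data laid out like the tensors described in the prerequisites\n",
    "demo_colors = torch.rand(demo_batch, demo_seq, demo_h, demo_w, 3)\n",
    "demo_depths = torch.rand(demo_batch, demo_seq, demo_h, demo_w, 1)\n",
    "demo_poses = torch.eye(4).repeat(demo_batch, demo_seq, 1, 1)\n",
    "\n",
    "# assumed layout: pinhole parameters embedded in a 4x4 identity matrix\n",
    "demo_K = torch.eye(4)\n",
    "demo_K[0, 0], demo_K[1, 1] = demo_fx, demo_fy\n",
    "demo_K[0, 2], demo_K[1, 2] = demo_cx, demo_cy\n",
    "demo_intrinsics = demo_K.view(1, 1, 4, 4).repeat(demo_batch, 1, 1, 1)\n",
    "\n",
    "# print the shape of every dummy tensor\n",
    "for name, tensor in [('colors', demo_colors), ('depths', demo_depths),\n",
    "                     ('intrinsics', demo_intrinsics), ('poses', demo_poses)]:\n",
    "    print(name, tuple(tensor.shape))"
   ]
  },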
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4aAK9HbIqmnv"
   },
   "source": [
    "# Basic PointFusion\n",
    "\n",
    "> **_NOTE:_**  Make sure to have ran the [prerequisits](#Prerequisits) section before running this section.\n",
    "\n",
    "This section demonstrates the basic use of PointFusion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xlqE42usqmnv"
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "slam = PointFusion(device=device)\n",
    "pointclouds, recovered_poses = slam(rgbdimages)\n",
    "pointclouds.plotly(0, max_num_points=15000).update_layout(autosize=False, width=600).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hPIkUUNqqmnx"
   },
   "source": [
    "# Step by step PointFusion\n",
    "\n",
    "> **_NOTE:_**  Make sure to have ran the [prerequisits](#Prerequisits) section before running this section.\n",
    "\n",
    "This section demonstrates building the pointcloud map from one frame at a time by calling the SLAM object's `.step()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "warZULSpqmny"
   },
   "outputs": [],
   "source": [
    "# load dataset\n",
    "dataset = ICL(icl_path, seqlen=4, height=240, width=320)\n",
    "loader = DataLoader(dataset=dataset, batch_size=1)\n",
    "colors, depths, intrinsics, poses, *_ = next(iter(loader))\n",
    "\n",
    "# create rgbdimages object\n",
    "rgbdimages = RGBDImages(colors, depths, intrinsics)\n",
    "\n",
    "# step by step SLAM\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "slam = PointFusion(device=device)\n",
    "\n",
    "pointclouds = Pointclouds(device=device)\n",
    "batch_size, seq_len = rgbdimages.shape[:2]\n",
    "initial_poses = torch.eye(4, device=device).view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1)\n",
    "prev_frame = None\n",
    "for s in range(seq_len):\n",
    "    live_frame = rgbdimages[:, s].to(device)\n",
    "    if s == 0 and live_frame.poses is None:\n",
    "        live_frame.poses = initial_poses\n",
    "    pointclouds, live_frame.poses = slam.step(pointclouds, live_frame, prev_frame)\n",
    "    prev_frame = live_frame\n",
    "pointclouds.plotly(0, max_num_points=15000).update_layout(autosize=False, width=600).show()"
   ]
  },
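  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above keeps one pose per frame, starting from the identity. As a generic illustration of how frame-to-frame motion estimates can accumulate into such poses, the sketch below chains relative transforms by matrix multiplication. The relative motions are made-up translations and `demo_translation` is a hypothetical helper; this shows the general idea and is not a description of gradslam's internal implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def demo_translation(tx):\n",
    "    # hypothetical helper: 4x4 transform that shifts by `tx` along the x axis\n",
    "    T = torch.eye(4)\n",
    "    T[0, 3] = tx\n",
    "    return T\n",
    "\n",
    "demo_pose = torch.eye(4)  # the first frame starts at the identity pose\n",
    "demo_trajectory = [demo_pose]\n",
    "for demo_rel in [demo_translation(0.1), demo_translation(0.2), demo_translation(0.3)]:\n",
    "    # compose the running pose with the newest relative motion\n",
    "    demo_pose = demo_pose @ demo_rel\n",
    "    demo_trajectory.append(demo_pose)\n",
    "\n",
    "# x coordinate of the camera centre after each step\n",
    "print([round(p[0, 3].item(), 2) for p in demo_trajectory])"
   ]
  },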
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qfpvR58Sqmn0"
   },
   "source": [
    "# Advanced visualization\n",
    "\n",
    "> **_NOTE:_**  Make sure to have ran the [prerequisits](#Prerequisits) section before running this section.\n",
    "\n",
    "This section demonstrates visualization of the pointcloud map as it gets updated from new rgbd frames. It also visualizes the poses with frustums in the 3d map. We use ground truth poses here (`odom=gt`) as the data sequences were fetched with a large dilation value (i.e. small fps) and ICP/gradICP won't work well in this case."
   ]
  },
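  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of `dilation` concrete, the sketch below lists which frames of the original trajectory would end up in one sequence. It assumes that `dilation` is the number of original frames skipped between two consecutive picks, so the stride is `dilation + 1`; `dilated_indices` is a hypothetical helper written for this illustration, not a gradslam function. Under that assumption, `dilation=19` leaves consecutive frames 20 original frames apart, which is why frame-to-frame alignment becomes hard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dilated_indices(start, seqlen, dilation):\n",
    "    # assumption: `dilation` original frames are skipped between consecutive picks\n",
    "    stride = dilation + 1\n",
    "    return [start + i * stride for i in range(seqlen)]\n",
    "\n",
    "# same seqlen and dilation as the dataset created later in this section\n",
    "print(dilated_indices(0, seqlen=8, dilation=19))"
   ]
  },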
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ie5TjiVtqmn1"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "def plotly_map_update_visualization(intermediate_pcs, poses, K, max_points_per_pc=50000, ms_per_frame=50):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        - intermediate_pcs (List[gradslam.Pointclouds]): list of gradslam.Pointclouds objects, each of batch size 1\n",
    "        - poses (torch.Tensor): poses for drawing frustums\n",
    "        - K (torch.Tensor): Intrinsics matrix\n",
    "        - max_points_per_pc (int): maximum number of points to plot for each pointcloud\n",
    "        - ms_per_frame (int): miliseconds per frame for the animation\n",
    "\n",
    "    Shape:\n",
    "        - poses: :math:`(L, 4, 4)`\n",
    "        - K: :math:`(4, 4)`\n",
    "    \"\"\"\n",
    "    def plotly_poses(poses, K):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            poses (np.ndarray):\n",
    "            K (np.ndarray):\n",
    "\n",
    "        Shapes:\n",
    "            - poses: :math:`(L, 4, 4)`\n",
    "            - K: :math:`(4, 4)`\n",
    "        \"\"\"\n",
    "        fx = abs(K[0, 0])\n",
    "        fy = abs(K[1, 1])\n",
    "        f = (fx + fy) / 2\n",
    "        cx = K[0, 2]\n",
    "        cy = K[1, 2]\n",
    "\n",
    "        cx = cx / f\n",
    "        cy = cy / f\n",
    "        f = 1.\n",
    "\n",
    "        pos_0 = np.array([0., 0., 0.])\n",
    "        fustum_0 = np.array(\n",
    "            [\n",
    "                [-cx, -cy, f],\n",
    "                [cx, -cy, f],\n",
    "                list(pos_0),\n",
    "                [-cx, -cy, f],\n",
    "                [-cx, cy, f],\n",
    "                list(pos_0),\n",
    "                [cx, cy, f],\n",
    "                [-cx, cy, f],\n",
    "                [cx, cy, f],\n",
    "                [cx, -cy, f],\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        traj = []\n",
    "        traj_frustums = []\n",
    "        for pose in poses:\n",
    "            rot = pose[:3, :3]\n",
    "            tvec = pose[:3, 3]\n",
    "\n",
    "            fustum_i = fustum_0 @ rot.T\n",
    "            fustum_i = fustum_i + tvec\n",
    "            pos_i = pos_0 + tvec\n",
    "\n",
    "            pos_i = np.round(pos_i, decimals=2)\n",
    "            fustum_i = np.round(fustum_i, decimals=2)\n",
    "\n",
    "            traj.append(pos_i)\n",
    "            traj_array = np.array(traj)\n",
    "            traj_frustum = [\n",
    "                go.Scatter3d(\n",
    "                    x=fustum_i[:, 0], y=fustum_i[:, 1], z=fustum_i[:, 2],\n",
    "                    marker=dict(\n",
    "                        size=0.1,\n",
    "                    ),\n",
    "                    line=dict(\n",
    "                        color='purple',\n",
    "                        width=4,\n",
    "                    )\n",
    "                ),\n",
    "                go.Scatter3d(\n",
    "                    x=pos_i[None, 0], y=pos_i[None, 1], z=pos_i[None, 2],\n",
    "                    marker=dict(\n",
    "                        size=6.,\n",
    "                        color='purple',\n",
    "                    )\n",
    "                ),\n",
    "                go.Scatter3d(\n",
    "                    x=traj_array[:, 0], y=traj_array[:, 1], z=traj_array[:, 2],\n",
    "                    marker=dict(\n",
    "                        size=0.1,\n",
    "                    ),\n",
    "                    line=dict(\n",
    "                        color='purple',\n",
    "                        width=2,\n",
    "                    )\n",
    "                ),\n",
    "            ]\n",
    "            traj_frustums.append(traj_frustum)\n",
    "        return traj_frustums\n",
    "\n",
    "    def frame_args(duration):\n",
    "        return {\n",
    "            \"frame\": {\"duration\": duration, \"redraw\": True},\n",
    "            \"mode\": \"immediate\",\n",
    "            \"fromcurrent\": True,\n",
    "            \"transition\": {\"duration\": duration, \"easing\": \"linear\"},\n",
    "        }\n",
    "\n",
    "    # visualization\n",
    "    scatter3d_list = [pc.plotly(0, as_figure=False, max_num_points=max_points_per_pc) for pc in intermediate_pcs]\n",
    "    traj_frustums = plotly_poses(poses.cpu().numpy(), K.cpu().numpy())\n",
    "    data = [[*frustum, scatter3d] for frustum, scatter3d in zip(traj_frustums, scatter3d_list)]\n",
    "\n",
    "    steps = [\n",
    "        {\"args\": [[i], frame_args(0)], \"label\": i, \"method\": \"animate\"}\n",
    "        for i in range(seq_len)\n",
    "    ]\n",
    "    sliders = [\n",
    "        {\n",
    "            \"active\": 0,\n",
    "            \"yanchor\": \"top\",\n",
    "            \"xanchor\": \"left\",\n",
    "            \"currentvalue\": {\"prefix\": \"Frame: \"},\n",
    "            \"pad\": {\"b\": 10, \"t\": 60},\n",
    "            \"len\": 0.9,\n",
    "            \"x\": 0.1,\n",
    "            \"y\": 0,\n",
    "            \"steps\": steps,\n",
    "        }\n",
    "    ]\n",
    "    updatemenus = [\n",
    "        {\n",
    "            \"buttons\": [\n",
    "                {\n",
    "                    \"args\": [None, frame_args(ms_per_frame)],\n",
    "                    \"label\": \"&#9654;\",\n",
    "                    \"method\": \"animate\",\n",
    "                },\n",
    "                {\n",
    "                    \"args\": [[None], frame_args(0)],\n",
    "                    \"label\": \"&#9724;\",\n",
    "                    \"method\": \"animate\",\n",
    "                },\n",
    "            ],\n",
    "            \"direction\": \"left\",\n",
    "            \"pad\": {\"r\": 10, \"t\": 70},\n",
    "            \"showactive\": False,\n",
    "            \"type\": \"buttons\",\n",
    "            \"x\": 0.1,\n",
    "            \"xanchor\": \"right\",\n",
    "            \"y\": 0,\n",
    "            \"yanchor\": \"top\",\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    fig = go.Figure()\n",
    "    frames = [{\"data\": frame, \"name\": i} for i, frame in enumerate(data)]\n",
    "    fig.add_traces(frames[0][\"data\"])\n",
    "    fig.update(frames=frames)\n",
    "    fig.update_layout(\n",
    "        updatemenus=updatemenus,\n",
    "        sliders=sliders,\n",
    "        showlegend=False,\n",
    "        scene=dict(\n",
    "            xaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),\n",
    "            yaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),\n",
    "            zaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),\n",
    "        )\n",
    "    )\n",
    "    fig.show()\n",
    "    return fig"
   ]
  },
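  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The frustum drawing above moves a canonical frustum into the map with `points @ rot.T + tvec`, which is the row-vector form of `R @ p + t`. The sketch below checks this on a made-up pose (a 90 degree rotation about the z axis followed by a shift of 1 along x); the pose and points are illustrative values, not data from the ICL sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# hypothetical pose: rotate 90 degrees about z, then shift by 1 along x\n",
    "demo_pose = np.array([\n",
    "    [0., -1., 0., 1.],\n",
    "    [1., 0., 0., 0.],\n",
    "    [0., 0., 1., 0.],\n",
    "    [0., 0., 0., 1.],\n",
    "])\n",
    "demo_points = np.array([\n",
    "    [0., 0., 0.],  # camera centre\n",
    "    [1., 0., 1.],  # a point in front of the camera\n",
    "])\n",
    "\n",
    "demo_rot, demo_tvec = demo_pose[:3, :3], demo_pose[:3, 3]\n",
    "# row-vector form of R @ p + t, applied to all points at once\n",
    "demo_world = demo_points @ demo_rot.T + demo_tvec\n",
    "\n",
    "# the camera centre lands on the translation; the second point is rotated, then shifted\n",
    "assert np.allclose(demo_world, [[1., 0., 0.], [1., 1., 1.]])\n",
    "print(demo_world)"
   ]
  },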
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hmw1TxJrqmn2"
   },
   "outputs": [],
   "source": [
    "dataset = ICL(icl_path, seqlen=8, dilation=19, height=60, width=80)\n",
    "loader = DataLoader(dataset=dataset, batch_size=1)\n",
    "colors, depths, intrinsics, poses, *_ = next(iter(loader))\n",
    "\n",
    "# create rgbdimages object\n",
    "rgbdimages = RGBDImages(colors, depths, intrinsics, poses)\n",
    "\n",
    "# step by step SLAM and store intermediate maps\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "slam = PointFusion(odom='gt', device=device)  # use gt poses because large dilation (small fps) makes ICP difficult\n",
    "pointclouds = Pointclouds(device=device)\n",
    "batch_size, seq_len = rgbdimages.shape[:2]\n",
    "initial_poses = torch.eye(4, device=device).view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1)\n",
    "prev_frame = None\n",
    "intermediate_pcs = []\n",
    "for s in range(seq_len):\n",
    "    live_frame = rgbdimages[:, s].to(device)\n",
    "    if s == 0 and live_frame.poses is None:\n",
    "        live_frame.poses = initial_poses\n",
    "    pointclouds, live_frame.poses = slam.step(pointclouds, live_frame, prev_frame)\n",
    "    prev_frame = live_frame if slam.odom != 'gt' else None\n",
    "    intermediate_pcs.append(pointclouds[0])\n",
    "\n",
    "# visualize\n",
    "rgbdimages.plotly(0).update_layout(autosize=False, height=600, width=400).show()\n",
    "fig = plotly_map_update_visualization(intermediate_pcs, poses[0], intrinsics[0, 0], 10000)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "pointfusion_tutorial.ipynb",
   "provenance": []
  },
  "jupytext": {
   "text_representation": {
    "extension": ".py",
    "format_name": "light",
    "format_version": "1.5",
    "jupytext_version": "1.6.0"
   }
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
